#pragma once

#include <torch/extension.h>

/// @brief Backward-pass entry point for the `p2e_logproba` operation.
///
/// Only the declaration lives in this header; the definition is in another
/// translation unit. Keep this signature byte-for-byte in sync with that
/// definition. For example, switching parameters to `const torch::Tensor&` here
/// alone would declare a different overload and fail at link time.
///
/// All tensors are taken by value. A `torch::Tensor` is a reference-counted
/// handle, so this shares the underlying storage rather than copying data.
///
/// NOTE(review): the parameter and return descriptions below are inferred from
/// names only. This header does not establish shapes, dtypes, devices, or
/// meanings. Confirm them against the implementation and the matching forward op.
///
/// @param grad_output   Presumably the upstream gradient of the loss w.r.t. the
///                      forward op's output. TODO confirm shape.
/// @param sxy           Presumably a tensor of (x, y) coordinates. TODO confirm
///                      what "s" denotes and the expected layout.
/// @param oxy           Presumably a second tensor of (x, y) coordinates
///                      (e.g. offsets/origins). TODO confirm.
/// @param invcov        Presumably inverse covariance matrices. TODO confirm
///                      shape (e.g. N x 2 x 2) and dtype.
/// @param logdet_invcov Presumably log-determinants of `invcov`. TODO confirm.
/// @param fids          Presumably integer index tensor (e.g. per-element ids
///                      into the other inputs). TODO confirm dtype and range.
/// @param max_idx       Presumably indices saved by the forward pass (e.g. an
///                      argmax). TODO confirm.
/// @return A tuple of four tensors. Presumably these are the gradients w.r.t.
///         four of the differentiable inputs, possibly sxy, oxy, invcov and
///         logdet_invcov in that order. TODO confirm order against the
///         definition and the autograd wrapper that consumes it.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> p2e_logproba_backward(
    torch::Tensor grad_output,
    torch::Tensor sxy,
    torch::Tensor oxy,
    torch::Tensor invcov,
    torch::Tensor logdet_invcov,
    torch::Tensor fids,
    torch::Tensor max_idx
);
